{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "48863301-737d-403b-a0c5-e0ab8ece9e84",
   "metadata": {},
   "source": [
    "# TorchXrayVision"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38ceaee7-56ea-4c71-8834-caa5a96799c1",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ebdd0c4-c9a5-4a9b-9904-760619fed935",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "!pip install numpy matplotlib torch torchvision torchxrayvision "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7efaa1dd-933f-4cd2-86ef-4c8aafbc16cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import skimage\n",
    "import torch\n",
    "import torchvision\n",
    "import matplotlib.pyplot as plt\n",
    "import torchxrayvision as xrv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d1e8e99-9eb1-44cc-b4f9-da0913ce79fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = xrv.baseline_models.chestx_det.PSPNet()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e854dd80-98d6-4e82-accf-dc8d218aeef6",
   "metadata": {},
   "outputs": [],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c4b7581-39e9-45a4-87ea-07d51cfd9270",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Single Image Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed4ec6b0-9289-48f1-a20f-9a334444d692",
   "metadata": {},
   "outputs": [],
   "source": [
    "import skimage.io\n",
    "import torch\n",
    "import torchxrayvision as xrv\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "# Load the grayscale image\n",
    "img = skimage.io.imread(\"path/to/image\")\n",
    "\n",
    "# Normalize the image to [-1024, 1024] range\n",
    "img = xrv.datasets.normalize(img, 255)\n",
    "\n",
    "# Since the image is already grayscale, simply add the channel dimension\n",
    "img = img[None, ...]\n",
    "\n",
    "# Apply the transformations\n",
    "transform = transforms.Compose([\n",
    "    xrv.datasets.XRayCenterCrop(),\n",
    "    xrv.datasets.XRayResizer(512),\n",
    "])\n",
    "\n",
    "img = transform(img)\n",
    "img = torch.from_numpy(img)\n",
    "\n",
    "print(img.shape)  # Verify the shape of the tensor\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db81901e-e068-4e8d-a117-e8f03305a0fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    pred = model(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa283e74-2b1f-4ae0-b53e-5d9e67076516",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize = (26,5))\n",
    "plt.subplot(1, len(model.targets) + 1, 1)\n",
    "plt.imshow(img[0], cmap='gray')\n",
    "for i in range(len(model.targets)):\n",
    "    plt.subplot(1, len(model.targets) + 1, i+2)\n",
    "    plt.imshow(pred[0, i])\n",
    "    plt.title(model.targets[i])\n",
    "    plt.axis('off')\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79a89bc4-ca29-4d62-a84d-70fa913cc4e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred = 1 / (1 + np.exp(-pred))  # sigmoid\n",
    "pred[pred < 0.5] = 0\n",
    "pred[pred > 0.5] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "369cae9b-a72b-4fc7-8d9c-5b269d7d45b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize = (26,5))\n",
    "plt.subplot(1, len(model.targets) + 1, 1)\n",
    "plt.imshow(img[0], cmap='gray')\n",
    "for i in range(len(model.targets)):\n",
    "    plt.subplot(1, len(model.targets) + 1, i+2)\n",
    "    plt.imshow(pred[0, i])\n",
    "    plt.title(model.targets[i])\n",
    "    plt.axis('off')\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0d10e22-2e2e-4e05-8008-0c890bd0ffc3",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Full Chest-XRay Datset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e912fb39-e2f2-460e-b00b-74f50924be8c",
   "metadata": {},
   "source": [
    "The following code generates a dataset with images + segmentation masks for the Chest Xray Dataset. The current label mapping is ideal for visually representing each class however, it is not ideal for training SPADEGAN. Either manually change label mapping to be 0-13 or use code after the next cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f684f984-d57a-46b6-95ad-41fcc939f898",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import skimage.io\n",
    "import torch\n",
    "import torchxrayvision as xrv\n",
    "import torchvision.transforms as transforms\n",
    "from skimage import img_as_ubyte  # Import this to convert images to 8-bit\n",
    "from skimage.exposure import rescale_intensity  # Import to rescale intensity\n",
    "\n",
    "# Set the dataset path and output directory\n",
    "data_dir = \"dataset\"  # Replace with your dataset path\n",
    "output_image_dir = \"output/path/images\"  # Directory to save the original images\n",
    "output_mask_dir = \"output/path/masks\"  # Directory to save the combined masks\n",
    "\n",
    "# Create the output directories if they don't exist\n",
    "os.makedirs(output_image_dir, exist_ok=True)\n",
    "os.makedirs(output_mask_dir, exist_ok=True)\n",
    "\n",
    "# Load the pre-trained PSPNet model from TorchXRayVision\n",
    "model = xrv.baseline_models.chestx_det.PSPNet()\n",
    "model.eval()  # Set the model to evaluation mode\n",
    "\n",
    "# Define the necessary transforms\n",
    "transform = transforms.Compose([\n",
    "    xrv.datasets.XRayCenterCrop(),\n",
    "    xrv.datasets.XRayResizer(512),\n",
    "])\n",
    "\n",
    "# Define label mapping for the combined mask\n",
    "labels = {\n",
    "    \"Left Clavicle\": 20,\n",
    "    \"Right Clavicle\": 40,\n",
    "    \"Left Scapula\": 60,\n",
    "    \"Right Scapula\": 80,\n",
    "    \"Left Lung\": 100,\n",
    "    \"Right Lung\": 120,\n",
    "    \"Left Hilus Pulmonis\": 140,\n",
    "    \"Right Hilus Pulmonis\": 160,\n",
    "    \"Heart\": 180,\n",
    "    \"Aorta\": 200,\n",
    "    \"Facies Diaphragmatica\": 220,\n",
    "    \"Mediastinum\": 240,\n",
    "    \"Weasand\": 255,\n",
    "    \"Spine\": 30  # Ensure distinct from others\n",
    "}\n",
    "\n",
    "# Process images from all subdirectories\n",
    "for root, dirs, files in os.walk(data_dir):\n",
    "    for file in files:\n",
    "        if file.endswith((\".png\", \".jpg\", \".jpeg\")):\n",
    "            # Load and preprocess the image\n",
    "            img_path = os.path.join(root, file)\n",
    "            img = skimage.io.imread(img_path)\n",
    "            img = xrv.datasets.normalize(img, 255)  # Normalize to [-1024, 1024]\n",
    "\n",
    "            if len(img.shape) == 2:  # If the image is grayscale (H, W)\n",
    "                img = img[None, ...]  # Add a channel dimension (1, H, W)\n",
    "\n",
    "            img = transform(img)\n",
    "            img_tensor = torch.from_numpy(img).unsqueeze(0)  # Add batch dimension\n",
    "\n",
    "            # Generate segmentation masks for each anatomical structure\n",
    "            with torch.no_grad():\n",
    "                pred = model(img_tensor)\n",
    "\n",
    "            # Apply sigmoid and thresholding\n",
    "            pred = 1 / (1 + np.exp(-pred))  # Sigmoid function\n",
    "            pred[pred < 0.5] = 0\n",
    "            pred[pred > 0.5] = 1\n",
    "\n",
    "            # Combine masks into one with distinct grayscale values\n",
    "            combined_mask = np.zeros(pred.shape[2:], dtype=np.uint8)\n",
    "            for i, label in enumerate(labels):\n",
    "                combined_mask[pred[0, i] > 0.5] = labels[label]\n",
    "\n",
    "            # Rescale the original image to [0, 1] before converting to 8-bit\n",
    "            img_rescaled = rescale_intensity(img.squeeze(), in_range=(-1024, 1024), out_range=(0, 1))\n",
    "            img_8bit = img_as_ubyte(img_rescaled)\n",
    "            combined_mask_8bit = img_as_ubyte(combined_mask)\n",
    "\n",
    "            # Create subdirectory in output directory corresponding to original image's subdirectory\n",
    "            rel_path = os.path.relpath(root, data_dir)\n",
    "            output_image_subdir = os.path.join(output_image_dir, rel_path)\n",
    "            output_mask_subdir = os.path.join(output_mask_dir, rel_path)\n",
    "            os.makedirs(output_image_subdir, exist_ok=True)\n",
    "            os.makedirs(output_mask_subdir, exist_ok=True)\n",
    "\n",
    "            # Save the original image and the combined mask\n",
    "            output_image_path = os.path.join(output_image_subdir, file)\n",
    "            output_mask_path = os.path.join(output_mask_subdir, f\"mask_{os.path.splitext(file)[0]}.png\")\n",
    "\n",
    "            skimage.io.imsave(output_image_path, img_8bit)\n",
    "            skimage.io.imsave(output_mask_path, combined_mask_8bit)\n",
    "\n",
    "            print(f\"Saved image: {output_image_path}\")\n",
    "            print(f\"Saved combined segmentation mask: {output_mask_path}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f74cc486-5cdf-4fd3-837f-0d215bf66468",
   "metadata": {},
   "source": [
    "The following code remaps the labels to be suitable for training SPADEGAN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2853dc9-a350-49de-8b14-8ae20677ebda",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "from skimage.io import imread, imsave\n",
    "\n",
    "# Define the remapping dictionary\n",
    "remap_dict = {\n",
    "    0: 0,\n",
    "    20: 1,\n",
    "    30: 2,\n",
    "    40: 3,\n",
    "    60: 4,\n",
    "    80: 5,\n",
    "    100: 6,\n",
    "    120: 7,\n",
    "    140: 8,\n",
    "    160: 9,\n",
    "    180: 10,\n",
    "    220: 11,\n",
    "    240: 12,\n",
    "    255: 13\n",
    "}\n",
    "\n",
    "# Function to remap mask values\n",
    "def remap_mask(mask, remap_dict):\n",
    "    remapped_mask = np.copy(mask)\n",
    "    for old_value, new_value in remap_dict.items():\n",
    "        remapped_mask[mask == old_value] = new_value\n",
    "    return remapped_mask\n",
    "\n",
    "# Directory containing the masks\n",
    "mask_dir = \"path/to/masks\"\n",
    "output_mask_dir = \"path/to/masks_remapped\"\n",
    "os.makedirs(output_mask_dir, exist_ok=True)\n",
    "\n",
    "# Process each mask\n",
    "for filename in os.listdir(mask_dir):\n",
    "    if filename.endswith(\".png\"):  # Assuming masks are in PNG format\n",
    "        mask_path = os.path.join(mask_dir, filename)\n",
    "        mask = imread(mask_path)\n",
    "\n",
    "        # Remap the mask values\n",
    "        remapped_mask = remap_mask(mask, remap_dict)\n",
    "\n",
    "        # Save the remapped mask\n",
    "        output_path = os.path.join(output_mask_dir, filename)\n",
    "        imsave(output_path, remapped_mask)\n",
    "        print(f\"Saved remapped mask: {output_path}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
